package com.inyourcode.core.spring;

import com.inyourcode.agent.HotFixAgent;
import com.inyourcode.core.util.StackTraceUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.io.ClassPathResource;
import org.springframework.core.io.Resource;
import org.springframework.core.io.support.PathMatchingResourcePatternResolver;
import org.springframework.core.io.support.ResourcePatternResolver;
import org.springframework.core.type.classreading.CachingMetadataReaderFactory;
import org.springframework.core.type.classreading.MetadataReader;
import org.springframework.core.type.classreading.MetadataReaderFactory;

import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
import java.lang.instrument.ClassDefinition;
import java.lang.instrument.UnmodifiableClassException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.jar.JarEntry;
import java.util.jar.JarInputStream;

/**
 * @author JackLei
 */
public class ClassHelper {
    private static final Logger LOGGER = LoggerFactory.getLogger(ClassHelper.class);

    /** File-name suffix that identifies compiled class entries inside a jar. */
    private static final String CLASS_SUFFIX = ".class";

    /** Size of the scratch buffer used when copying one jar entry into memory. */
    private static final int READ_BUFFER_SIZE = 1 << 14;

    /**
     * Redefines already-loaded classes through the instrumentation instance returned by
     * {@code HotFixAgent.getInst()}.
     *
     * <p>Classes are redefined one at a time, in the iteration order of {@code bytesMap}. Processing
     * stops at the first failure, so classes redefined before that failure keep their new definition.
     *
     * @param classLoader loader used to resolve each class name to its loaded {@link Class}
     * @param bytesMap    class name (dotted form, e.g. {@code com.foo.Bar}) mapped to the new
     *                    class-file bytes for that class
     * @return {@code true} if every class was redefined, {@code false} as soon as one redefinition fails
     */
    public static boolean redefineClass(ClassLoader classLoader, Map<String, byte[]> bytesMap) {
        for (Map.Entry<String, byte[]> entry : bytesMap.entrySet()) {
            String className = entry.getKey();
            byte[] bytes = entry.getValue();
            try {
                LOGGER.info("Begin Class[{}] redefinition", className);
                ClassDefinition classDefinition = new ClassDefinition(classLoader.loadClass(className), bytes);
                HotFixAgent.getInst().redefineClasses(classDefinition);
                LOGGER.info("Class[{}] redefinition successed", className);
            } catch (ClassNotFoundException | UnmodifiableClassException | UnsupportedOperationException e) {
                // UnsupportedOperationException is how redefineClasses rejects unsupported changes
                // (e.g. adding/removing fields or methods); report it like the checked failures.
                LOGGER.error("Class[{}] redefinition failed, exception:{}", className, StackTraceUtil.stackTrace(e));
                return false;
            } catch (LinkageError e) {
                // Malformed or unverifiable bytes surface as ClassFormatError, VerifyError, etc.,
                // all LinkageError subclasses; keep the boolean contract instead of propagating.
                LOGGER.error("Class[{}] redefinition failed", className, e);
                return false;
            }
        }

        return true;
    }

    /**
     * Reads the bytes of every class contained in a classpath jar.
     *
     * @param jarName classpath location of the jar, resolved with {@link ClassPathResource}
     * @return class name (dotted form) mapped to the class-file bytes
     * @throws IOException if the jar cannot be opened or read
     */
    public static Map<String, byte[]> readClassBytes(String jarName) throws IOException {
        return readClassBytes(jarName, new HashSet<>());
    }

    /**
     * Reads the bytes of classes contained in a classpath jar, optionally restricted to a set of names.
     *
     * @param jarName         classpath location of the jar, resolved with {@link ClassPathResource}
     * @param filterClazzName class names (dotted form, e.g. {@code com.foo.Bar}) to keep; an empty
     *                        or {@code null} set keeps every class in the jar
     * @return class name (dotted form) mapped to the class-file bytes
     * @throws IOException if the jar cannot be opened or read
     */
    public static Map<String, byte[]> readClassBytes(String jarName, Set<String> filterClazzName) throws IOException {
        boolean acceptAll = filterClazzName == null || filterClazzName.isEmpty();
        Map<String, byte[]> clazzBytesMap = new HashMap<>();
        ClassPathResource classPathResource = new ClassPathResource(jarName);
        // One scratch buffer reused for every entry; the per-entry output stream is what gets stored.
        byte[] buffer = new byte[READ_BUFFER_SIZE];

        try (JarInputStream jis = new JarInputStream(classPathResource.getInputStream())) {
            for (JarEntry je; (je = jis.getNextJarEntry()) != null; ) {
                String entryName = je.getName();
                if (!entryName.endsWith(CLASS_SUFFIX)) {
                    continue;
                }

                // Jar (ZIP) entry names always use '/' as the separator regardless of the host OS,
                // so convert that literally; File.separator would be "\" on Windows, which is both
                // the wrong character and an invalid regular expression for replaceAll.
                String className = entryName
                        .substring(0, entryName.length() - CLASS_SUFFIX.length())
                        .replace('/', '.');
                if (!acceptAll && !filterClazzName.contains(className)) {
                    continue;
                }

                // JarInputStream.read returns -1 at the end of the current entry, not the whole jar.
                ByteArrayOutputStream bos = new ByteArrayOutputStream();
                for (int count; (count = jis.read(buffer)) != -1; ) {
                    bos.write(buffer, 0, count);
                }
                clazzBytesMap.put(className, bos.toByteArray());

                LOGGER.info("Read the byte array of class [{}]", className);
            }
        }

        return clazzBytesMap;
    }

    /**
     * Scans the classpath for class resources matching a pattern and loads those accepted by a filter.
     *
     * <p>Resources are located with the pattern {@code "classpath*:" + path}. Each resource's class
     * name is read from its metadata, and the class is loaded through
     * {@link ClassLoader#getSystemClassLoader()}.
     *
     * <p>NOTE(review): the resolver is constructed with its default class loader while classes are
     * loaded via the system class loader; confirm this is intended for environments with custom
     * loaders (e.g. application servers or fat-jar launchers).
     *
     * @param path   resource pattern appended to {@code classpath*:}
     * @param filter predicate deciding which loaded classes are returned
     * @return loaded classes accepted by {@code filter}, in resource order
     * @throws IOException            if resources or their metadata cannot be read
     * @throws ClassNotFoundException if a matched class cannot be loaded by the system class loader
     * @throws ClassScannerException  if {@code filter} rejects the scan by throwing
     */
    public static List<Class> scan(String path, IClassScannerFilter filter) throws IOException, ClassNotFoundException, ClassScannerException {
        List<Class> classList = new ArrayList<>();
        ResourcePatternResolver resolver = new PathMatchingResourcePatternResolver();
        MetadataReaderFactory readerFactory = new CachingMetadataReaderFactory();
        Resource[] resources = resolver.getResources("classpath*:" + path);
        ClassLoader loader = ClassLoader.getSystemClassLoader();
        for (Resource resource : resources) {
            MetadataReader reader = readerFactory.getMetadataReader(resource);
            String className = reader.getClassMetadata().getClassName();
            Class clazz = loader.loadClass(className);
            if (filter.filter(clazz)) {
                classList.add(clazz);
            }
        }
        return classList;
    }

    /**
     * Predicate applied by {@link #scan(String, IClassScannerFilter)} to each loaded class.
     */
    public interface IClassScannerFilter {
        /** Filter that accepts every class. */
        IClassScannerFilter DEFAULT = clazz -> true;

        /**
         * Decides whether a scanned class is included in the result.
         *
         * @param clazz the loaded class
         * @return {@code true} to include the class
         * @throws ClassScannerException to abort the scan
         */
        boolean filter(Class clazz) throws ClassScannerException;
    }

    /**
     * Checked exception a {@link IClassScannerFilter} may throw to abort a class scan.
     */
    public static class ClassScannerException extends Exception {
        public ClassScannerException() {
        }

        public ClassScannerException(String message) {
            super(message);
        }

        public ClassScannerException(String message, Throwable cause) {
            super(message, cause);
        }

        public ClassScannerException(Throwable cause) {
            super(cause);
        }

        public ClassScannerException(String message, Throwable cause, boolean enableSuppression, boolean writableStackTrace) {
            super(message, cause, enableSuppression, writableStackTrace);
        }
    }
}
